#include "opencl_source_map.hpp" 
namespace MNN { 
#ifndef MNN_OPENCL_BUFFER_CLOSED
#ifdef MNN_SUPPORT_INTEL_SUBGROUP
// OpenCL C program source for the Winograd transforms with a 2x2 output tile
// computed from a 4x4 input tile ("2_3_1"), plus the transform-domain GEMM.
// The kernels use cl_intel_subgroups block reads/writes and shuffles; the
// kernels that use them require a subgroup size of 16. This source is only
// compiled when the OpenCL buffer backend is enabled and
// MNN_SUPPORT_INTEL_SUBGROUP is defined.
//
// The adjacent string literals below are concatenated by the C++ compiler
// into one program source; C++ comments placed between them are not part of
// that source.
const char* winogradTransform_subgroup_buf = 
// Enable half-precision arithmetic when building for FP16. FLOAT/FLOAT2/...
// are not defined here; presumably they come from host build options
// (TODO confirm).
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// GLOBAL_SIZE_DIM2 expands to the two leading kernel parameters that carry the
// logical global size. UNIFORM_BOUNDRY_CHECK returns early for work-items
// outside it. The empty "" pieces only split the C++ literal, so the OpenCL
// preprocessor still sees each #define on a single line.
// NOTE(review): several kernels issue subgroup block reads/writes after this
// early return, and those require every lane of the subgroup to take part.
// Presumably the host sizes the NDRange so that whole subgroups pass or fail
// together -- confirm.
"#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n"
"#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n"
" \n"
"#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n"
// winoTransSrcBuf2_3_1_c16_c16: source (input) transform. Reads a
// 16-channel-blocked, horizontally padded input and writes the
// 16-channel-blocked transform buffer.
//
// Work mapping:
// - pos.x is the tile index: unitWidth_idx = pos.x % unitWidth and
//   unitHeight_idx = pos.x / unitWidth.
// - Lane sglid of the 16-wide subgroup handles channel sglid of channel block
//   pos_y = get_group_id(1). This assumes a local size of 16 along dim 1
//   (TODO confirm on host).
//
// The 4x4 window starts at (2*unitWidth_idx - padX, 2*unitHeight_idx - padY).
// Rows outside [0, srcHeight) read as zero. Columns are NOT bounds-checked,
// so the input must carry input_pad_left/input_pad_right padding columns wide
// enough to cover padX and the right-edge overrun (TODO confirm on host).
//
// Each of the 16 results goes to its own transform slot; consecutive slots
// are batch_offset = srcChannelC16*dstHeight*16 elements apart.
// unitHeight, srcChannelC4 and batch are not used by this kernel.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void winoTransSrcBuf2_3_1_c16_c16(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,// 0\n"
" __global FLOAT* uOutput,__private const int unitWidth,\n"
" __private const int unitHeight,// 3\n"
" __private const int padX,__private const int padY,\n"
" __private const int srcWidth,// 6\n"
" __private const int srcHeight,__private const int srcChannelC4,__private const int srcChannelC16,__private const int dstHeight,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" const int unitWidth_idx=pos.x % unitWidth;\n"
" const int unitHeight_idx=pos.x/unitWidth;\n"
" const int sglid=get_sub_group_local_id();\n"
" const int pos_y=get_group_id(1);\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" int src_pitch=srcWidth+input_pad_left+input_pad_right;\n"
" \n"
" {\n"
" int sxStart=(realPos.x)*2-padX;\n"
" int syStart=(realPos.y)*2-padY;\n"
" FLOAT4 S[4];\n"
" \n"
// Subgroup-uniform base offset (no +sglid). Block reads take one shared
// pointer, and each lane picks out its own channel from the 16-wide
// interleave.
" int inp_offset=(((batchOffset*srcChannelC16+pos_y)*srcHeight+syStart)*src_pitch+sxStart+input_pad_left)*16;\n"
" for(int i=0; i<4; ++i){\n"
" int sy=i+syStart;\n"
" if(sy<0 || sy >= srcHeight){\n"
" S[i]=(FLOAT4)0;\n"
" }else{\n"
// Block read: lane l receives base[j*16 + l] for j = 0..3, i.e. four
// consecutive pixels of channel l in row sy.
"#ifdef MNN_SUPPORT_FP16\n"
" S[i]=as_half4(intel_sub_group_block_read_us4((__global ushort*)(uInput+inp_offset)));\n"
"#else\n"
" S[i]=as_float4(intel_sub_group_block_read4((__global uint*)(uInput+inp_offset)));\n"
"#endif\n"
" }\n"
" inp_offset += 16*src_pitch;\n"
" }\n"
// Both passes apply the 1-D transform
//   d -> (d0-d2, 0.5*d1+0.5*d2, -0.5*d1+0.5*d2, -d1+d3).
// First pass: down each of the 4 columns, giving m<col><k>.
// Second pass: across the columns of each m row. Slot 4*k+j receives
// transform j of row-result k.
" FLOAT m00=+S[0].s0-S[2].s0;\n"
" FLOAT m10=+S[0].s1-S[2].s1;\n"
" FLOAT m20=+S[0].s2-S[2].s2;\n"
" FLOAT m30=+S[0].s3-S[2].s3;\n"
" FLOAT m01=+(FLOAT)0.5f*S[1].s0+(FLOAT)0.5f*S[2].s0;\n"
" FLOAT m11=+(FLOAT)0.5f*S[1].s1+(FLOAT)0.5f*S[2].s1;\n"
" FLOAT m21=+(FLOAT)0.5f*S[1].s2+(FLOAT)0.5f*S[2].s2;\n"
" FLOAT m31=+(FLOAT)0.5f*S[1].s3+(FLOAT)0.5f*S[2].s3;\n"
" FLOAT m02=-(FLOAT)0.5f*S[1].s0+(FLOAT)0.5f*S[2].s0;\n"
" FLOAT m12=-(FLOAT)0.5f*S[1].s1+(FLOAT)0.5f*S[2].s1;\n"
" FLOAT m22=-(FLOAT)0.5f*S[1].s2+(FLOAT)0.5f*S[2].s2;\n"
" FLOAT m32=-(FLOAT)0.5f*S[1].s3+(FLOAT)0.5f*S[2].s3;\n"
" FLOAT m03=-S[1].s0+S[3].s0;\n"
" FLOAT m13=-S[1].s1+S[3].s1;\n"
" FLOAT m23=-S[1].s2+S[3].s2;\n"
" FLOAT m33=-S[1].s3+S[3].s3;\n"
" \n"
// Stores are plain scalar stores, so the per-lane +sglid lives in out_offset.
" //NC4HW4 [alpha*alpha,srcChannelC16,dstHeight,16]\n"
" //index: [0,pos.y/16,pos.x,0]\n"
" int out_offset=(pos_y*dstHeight+pos.x)*16+sglid;\n"
" int batch_offset=srcChannelC16*dstHeight*16;\n"
" uOutput[out_offset+0*batch_offset]=+m00-m20;\n"
" uOutput[out_offset+1*batch_offset]=+(FLOAT)0.5f*m10+(FLOAT)0.5f*m20;\n"
" uOutput[out_offset+2*batch_offset]=-(FLOAT)0.5f*m10+(FLOAT)0.5f*m20;\n"
" uOutput[out_offset+3*batch_offset]=-m10+m30;\n"
" uOutput[out_offset+4*batch_offset]=+m01-m21;\n"
" uOutput[out_offset+5*batch_offset]=+(FLOAT)0.5f*m11+(FLOAT)0.5f*m21;\n"
" uOutput[out_offset+6*batch_offset]=-(FLOAT)0.5f*m11+(FLOAT)0.5f*m21;\n"
" uOutput[out_offset+7*batch_offset]=-m11+m31;\n"
" uOutput[out_offset+8*batch_offset]=+m02-m22;\n"
" uOutput[out_offset+9*batch_offset]=+(FLOAT)0.5f*m12+(FLOAT)0.5f*m22;\n"
" uOutput[out_offset+10*batch_offset]=-(FLOAT)0.5f*m12+(FLOAT)0.5f*m22;\n"
" uOutput[out_offset+11*batch_offset]=-m12+m32;\n"
" uOutput[out_offset+12*batch_offset]=+m03-m23;\n"
" uOutput[out_offset+13*batch_offset]=+(FLOAT)0.5f*m13+(FLOAT)0.5f*m23;\n"
" uOutput[out_offset+14*batch_offset]=-(FLOAT)0.5f*m13+(FLOAT)0.5f*m23;\n"
" uOutput[out_offset+15*batch_offset]=-m13+m33;\n"
" }\n"
"}\n"
// winoTransDstBuf2_3_1_c16_c16: destination (output) transform.
//
// For one tile (pos.x) and one channel, this kernel:
// - reads the 16 transform slots from the 16-channel-blocked transform buffer;
// - applies the output transform and adds bias;
// - applies RELU/RELU6 when those macros are defined;
// - writes up to a 2x2 patch into a 16-channel-blocked output whose rows are
//   dst_pitch = output_pad_left + dstWidth + output_pad_right pixels wide.
//
// Work-items on the first/last tile column also zero the left/right padding
// of the output rows they own.
//
// Lane sglid handles channel sglid of block pos_y = get_group_id(1). This
// assumes a local size of 16 along dim 1 (TODO confirm on host).
// unitHeight, dstChannelC4 and batch are not used by this kernel.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void winoTransDstBuf2_3_1_c16_c16(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,\n"
" __global const FLOAT* uBias,\n"
" __global FLOAT* uOutput,\n"
" __private const int unitWidth,//wUnit\n"
" __private const int unitHeight,//hUnit\n"
" __private const int dstWidth,\n"
" __private const int dstHeight,\n"
" __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" const int unitWidth_idx=pos.x % unitWidth;\n"
" const int unitHeight_idx=pos.x/unitWidth;\n"
" const int sglid=get_sub_group_local_id();\n"
" const int pos_y=get_group_id(1);\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" \n"
" FLOAT bias=uBias[pos.y];\n"
" {\n"
" int oyStart=realPos.y*2;\n"
" int oxStart=realPos.x*2;\n"
// With an odd dstHeight, the last tile row owns only row oyStart. Every store
// to row oyStart+1, including padding, must be guarded by has_row1;
// otherwise it lands in the next plane, or past the end of the buffer for
// the last plane.
" const bool has_row1=(oyStart+1<dstHeight);\n"
" \n"
" //[alpha2,dstChannelC16,wUnit*hUnit,16]\n"
" //index: [slot,pos_y,pos.x,sglid]\n"
" const int inp_offset=(pos_y*srcWidth+pos.x)*16+sglid;\n"
" const int ic_offset=16*srcWidth*dstChannelC16;\n"
" FLOAT S00=uInput[inp_offset+ic_offset*0];\n"
" FLOAT S10=uInput[inp_offset+ic_offset*1];\n"
" FLOAT S20=uInput[inp_offset+ic_offset*2];\n"
" FLOAT S30=uInput[inp_offset+ic_offset*3];\n"
" FLOAT S01=uInput[inp_offset+ic_offset*4];\n"
" FLOAT S11=uInput[inp_offset+ic_offset*5];\n"
" FLOAT S21=uInput[inp_offset+ic_offset*6];\n"
" FLOAT S31=uInput[inp_offset+ic_offset*7];\n"
" FLOAT S02=uInput[inp_offset+ic_offset*8];\n"
" FLOAT S12=uInput[inp_offset+ic_offset*9];\n"
" FLOAT S22=uInput[inp_offset+ic_offset*10];\n"
" FLOAT S32=uInput[inp_offset+ic_offset*11];\n"
" FLOAT S03=uInput[inp_offset+ic_offset*12];\n"
" FLOAT S13=uInput[inp_offset+ic_offset*13];\n"
" FLOAT S23=uInput[inp_offset+ic_offset*14];\n"
" FLOAT S33=uInput[inp_offset+ic_offset*15];\n"
// Output transform: (a0+a1+a2, a1-a2+a3), applied first across slot rows
// (m<x><row>) and then across x.
" FLOAT m00=+S00+S01+S02;\n"
" FLOAT m10=+S10+S11+S12;\n"
" FLOAT m20=+S20+S21+S22;\n"
" FLOAT m30=+S30+S31+S32;\n"
" FLOAT m01=+S01-S02+S03;\n"
" FLOAT m11=+S11-S12+S13;\n"
" FLOAT m21=+S21-S22+S23;\n"
" FLOAT m31=+S31-S32+S33;\n"
" \n"
" //[batch,dstChannelC16,dstHeight,dst_pitch,16]\n"
" //index: [batchOffset,pos_y,oyStart,oxStart+output_pad_left,sglid]\n"
" int dst_pitch=dstWidth+output_pad_left+output_pad_right;\n"
// Two offsets are kept:
// - out_base is identical in every lane; it is the subgroup-uniform, aligned
//   base pointer that intel_sub_group_block_write* requires.
// - out_offset = out_base + sglid addresses this lane's channel and is used
//   only for scalar stores.
" int out_base=(((batchOffset*dstChannelC16+ pos_y)*dstHeight+oyStart)*dst_pitch+oxStart+output_pad_left)*16;\n"
" int out_offset=out_base+sglid;\n"
" {\n"
" FLOAT2 res=(FLOAT2)(bias+m00+m10+m20,bias+m10-m20+m30);\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT2)0,(FLOAT2)6);\n"
"#endif\n"
"#if OUTPUT_LEFTOVERS\n"
" uOutput[out_offset]=res.x;\n"
" if(oxStart+1< dstWidth){\n"
" uOutput[out_offset+16]=res.y;\n"
" }\n"
"#else\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us2((__global ushort*)(uOutput+out_base),as_ushort2(res));\n"
"#else\n"
" intel_sub_group_block_write2((__global uint*)(uOutput+out_base),as_uint2(res));\n"
"#endif\n"
"#endif //OUTPUT_LEFTOVERS\n"
" }\n"
" if(has_row1){\n"
" FLOAT2 res=(FLOAT2)(bias+m01+m11+m21,bias+m11-m21+m31);\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT2)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT2)0,(FLOAT2)6);\n"
"#endif\n"
"#if OUTPUT_LEFTOVERS\n"
" uOutput[out_offset+16*dst_pitch]=res.x;\n"
" if(oxStart+1< dstWidth){\n"
" uOutput[out_offset+16+16*dst_pitch]=res.y;\n"
" }\n"
"#else\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us2((__global ushort*)(uOutput+out_base+16*dst_pitch),as_ushort2(res));\n"
"#else\n"
" intel_sub_group_block_write2((__global uint*)(uOutput+out_base+16*dst_pitch),as_uint2(res));\n"
"#endif\n"
"#endif //OUTPUT_LEFTOVERS\n"
" }\n"
// Padding: the first tile column zeroes the left pad and the last tile column
// zeroes the right pad. This covers row oyStart always, and row oyStart+1
// only if it exists.
" if(unitWidth_idx == 0){\n"
" int pad_offset=(((batchOffset*dstChannelC16+ pos_y)*dstHeight+oyStart)*dst_pitch)*16+sglid;\n"
" for(int i=0; i<output_pad_left; ++i){\n"
" uOutput[pad_offset+i*16]=0;\n"
" if(has_row1){\n"
" uOutput[pad_offset+(i+dst_pitch)*16]=0;\n"
" }\n"
" }\n"
" }\n"
" if(unitWidth_idx == unitWidth-1){\n"
" int pad_offset=(((batchOffset*dstChannelC16+ pos_y)*dstHeight+oyStart)*dst_pitch+output_pad_left+dstWidth)*16+sglid;\n"
" for(int i=0; i<output_pad_right; ++i){\n"
" uOutput[pad_offset+i*16]=0;\n"
" if(has_row1){\n"
" uOutput[pad_offset+(i+dst_pitch)*16]=0;\n"
" }\n"
" }\n"
" }\n"
" }\n"
"}\n"
// winoTransSrcBuf2_3_1_c4_c16: source (input) transform. Reads a
// 4-channel-blocked input and writes the 16-channel-blocked transform buffer.
// It uses no subgroup attribute or subgroup built-ins; every work-item is
// independent.
//
// Work mapping: pos.x is the tile index, as in the c16 variant; pos.y selects
// a 4-channel block. The whole FLOAT4 of 4 channels is processed at once.
//
// Unlike the c16 variant, each of the 16 taps is bounds-checked in both x and
// y and reads as zero outside the image. No padded input columns are assumed,
// and input_pad_left/input_pad_right are unused.
//
// Output goes to channel block pos.y/4 of the C16 transform buffer, at
// sub-offset (pos.y%4)*4. unitHeight, srcChannelC4 and batch are otherwise
// unused except as shown.
"__kernel void winoTransSrcBuf2_3_1_c4_c16(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,// 0\n"
" __global FLOAT* uOutput,__private const int unitWidth,\n"
" __private const int unitHeight,// 3\n"
" __private const int padX,__private const int padY,\n"
" __private const int srcWidth,// 6\n"
" __private const int srcHeight,__private const int srcChannelC4,__private const int srcChannelC16,__private const int dstHeight,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int input_pad_left,__private const int input_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" \n"
" {\n"
" int sxStart=(realPos.x)*2-padX;\n"
" int syStart=(realPos.y)*2-padY;\n"
// S<x><y> holds the 4 channels at window position (sxStart+x, syStart+y).
" FLOAT4 S00;\n"
" FLOAT4 S10;\n"
" FLOAT4 S20;\n"
" FLOAT4 S30;\n"
" FLOAT4 S01;\n"
" FLOAT4 S11;\n"
" FLOAT4 S21;\n"
" FLOAT4 S31;\n"
" FLOAT4 S02;\n"
" FLOAT4 S12;\n"
" FLOAT4 S22;\n"
" FLOAT4 S32;\n"
" FLOAT4 S03;\n"
" FLOAT4 S13;\n"
" FLOAT4 S23;\n"
" FLOAT4 S33;\n"
" \n"
// The input offset puts the channel block outside batch
// ((batchOffset + pos.y*batch) planes of srcHeight x srcWidth x 4). This
// looks like a [C4, batch, H, W, 4] layout -- TODO confirm against the host.
// inp_offset can be negative at the top/left border. Each tap dereferences it
// only when that tap is in bounds, because the ?: operator evaluates only the
// selected operand.
" int inp_offset=(((batchOffset+pos.y*batch)*srcHeight+syStart)*srcWidth+sxStart)*4;\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S00=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S10=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=0+syStart;\n"
" \n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S20=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=0+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S30=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S01=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S11=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S21=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=1+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S31=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+4*srcWidth+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S02=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S12=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S22=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=2+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S32=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+8*srcWidth+12);\n"
" }\n"
" {\n"
" int sx=0+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S03=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth);\n"
" }\n"
" {\n"
" int sx=1+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S13=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+4);\n"
" }\n"
" {\n"
" int sx=2+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S23=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+8);\n"
" }\n"
" {\n"
" int sx=3+sxStart;\n"
" int sy=3+syStart;\n"
" bool outBound=(sx<0 || sx >= srcWidth || sy<0 || sy >= srcHeight);\n"
" S33=outBound ? (FLOAT4)(0) : vload4(0,uInput+inp_offset+12*srcWidth+12);\n"
" }\n"
// Same separable 1-D transform as the c16 variant, vectorized over 4 channels:
//   d -> (d0-d2, 0.5*d1+0.5*d2, -0.5*d1+0.5*d2, -d1+d3).
// It is applied along y first (m<x><k>) and then along x.
" FLOAT4 m00=+S00-S02;\n"
" FLOAT4 m10=+S10-S12;\n"
" FLOAT4 m20=+S20-S22;\n"
" FLOAT4 m30=+S30-S32;\n"
" FLOAT4 m01=+(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m11=+(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m21=+(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m31=+(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m02=-(FLOAT)0.5f*S01+(FLOAT)0.5f*S02;\n"
" FLOAT4 m12=-(FLOAT)0.5f*S11+(FLOAT)0.5f*S12;\n"
" FLOAT4 m22=-(FLOAT)0.5f*S21+(FLOAT)0.5f*S22;\n"
" FLOAT4 m32=-(FLOAT)0.5f*S31+(FLOAT)0.5f*S32;\n"
" FLOAT4 m03=-S01+S03;\n"
" FLOAT4 m13=-S11+S13;\n"
" FLOAT4 m23=-S21+S23;\n"
" FLOAT4 m33=-S31+S33;\n"
" \n"
// Each FLOAT4 covers 4 of the 16 channels in block pos.y/4. Slot 4*k+j is
// stored batch_offset elements after slot 0.
" //NC4HW4 [alpha*alpha,srcChannelC16,dstHeight,16]\n"
" //index: [0,pos.y/4,pos.x,pos.y % 4]\n"
" int out_offset=((pos.y/4)*dstHeight+pos.x)*16+(pos.y % 4)*4;\n"
" int batch_offset=srcChannelC16*dstHeight*16;\n"
" vstore4(+m00-m20,0,uOutput+out_offset+0*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m10+(FLOAT)0.5f*m20,0,uOutput+out_offset+1*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m10+(FLOAT)0.5f*m20,0,uOutput+out_offset+2*batch_offset);\n"
" vstore4(-m10+m30,0,uOutput+out_offset+3*batch_offset);\n"
" vstore4(+m01-m21,0,uOutput+out_offset+4*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m11+(FLOAT)0.5f*m21,0,uOutput+out_offset+5*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m11+(FLOAT)0.5f*m21,0,uOutput+out_offset+6*batch_offset);\n"
" vstore4(-m11+m31,0,uOutput+out_offset+7*batch_offset);\n"
" vstore4(+m02-m22,0,uOutput+out_offset+8*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m12+(FLOAT)0.5f*m22,0,uOutput+out_offset+9*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m12+(FLOAT)0.5f*m22,0,uOutput+out_offset+10*batch_offset);\n"
" vstore4(-m12+m32,0,uOutput+out_offset+11*batch_offset);\n"
" vstore4(+m03-m23,0,uOutput+out_offset+12*batch_offset);\n"
" vstore4(+(FLOAT)0.5f*m13+(FLOAT)0.5f*m23,0,uOutput+out_offset+13*batch_offset);\n"
" vstore4(-(FLOAT)0.5f*m13+(FLOAT)0.5f*m23,0,uOutput+out_offset+14*batch_offset);\n"
" vstore4(-m13+m33,0,uOutput+out_offset+15*batch_offset);\n"
" }\n"
"}\n"
// winoTransDstBuf2_3_1_c16_c4: destination (output) transform. Reads the
// 16-channel-blocked transform buffer and writes a 4-channel-blocked output
// with no row padding. It uses no subgroup attribute or subgroup built-ins.
//
// Each work-item handles one 4-channel block (pos.y) of one tile (pos.x). It:
// - reads 16 FLOAT4 slot values from block pos.y/4 at sub-offset (pos.y%4)*4;
// - adds the FLOAT4 bias uBias[pos.y*4 .. pos.y*4+3];
// - applies RELU/RELU6 when those macros are defined;
// - stores up to a 2x2 patch, bounds-checking every store against
//   dstWidth/dstHeight.
//
// unitHeight, dstChannelC4, output_pad_left and output_pad_right are unused.
"__kernel void winoTransDstBuf2_3_1_c16_c4(GLOBAL_SIZE_DIM2\n"
" __global const FLOAT* uInput,\n"
" __global const FLOAT* uBias,\n"
" __global FLOAT* uOutput,\n"
" __private const int unitWidth,//wUnit\n"
" __private const int unitHeight,//hUnit\n"
" __private const int dstWidth,\n"
" __private const int dstHeight,\n"
" __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth,\n"
" __private const int batchOffset,\n"
" __private const int batch,\n"
" __private const int output_pad_left,__private const int output_pad_right) {\n"
" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n"
" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n"
" int unitWidth_idx=pos.x % unitWidth;\n"
" int unitHeight_idx=pos.x/unitWidth;\n"
" int2 realPos=(int2)(unitWidth_idx,unitHeight_idx);\n"
" \n"
" FLOAT4 bias=vload4(0,uBias+pos.y*4);\n"
" {\n"
" int oyStart=realPos.y*2;\n"
" int oxStart=realPos.x*2;\n"
" \n"
// Slot s of this tile/channel group sits at inp_offset + s*ic_offset.
" //NC4HW4 [alpha2,dstChannelC16,wUnit*hUnit,16]\n"
" //index: [0,pos.y/4,pos.x,pos.y%4]\n"
" const int inp_offset=((pos.y/4)*srcWidth+pos.x)*16+(pos.y % 4)*4;\n"
" const int ic_offset=16*srcWidth*dstChannelC16;\n"
" FLOAT4 S00=vload4(0,uInput+inp_offset+ic_offset*0);\n"
" FLOAT4 S10=vload4(0,uInput+inp_offset+ic_offset*1);\n"
" FLOAT4 S20=vload4(0,uInput+inp_offset+ic_offset*2);\n"
" FLOAT4 S30=vload4(0,uInput+inp_offset+ic_offset*3);\n"
" FLOAT4 S01=vload4(0,uInput+inp_offset+ic_offset*4);\n"
" FLOAT4 S11=vload4(0,uInput+inp_offset+ic_offset*5);\n"
" FLOAT4 S21=vload4(0,uInput+inp_offset+ic_offset*6);\n"
" FLOAT4 S31=vload4(0,uInput+inp_offset+ic_offset*7);\n"
" FLOAT4 S02=vload4(0,uInput+inp_offset+ic_offset*8);\n"
" FLOAT4 S12=vload4(0,uInput+inp_offset+ic_offset*9);\n"
" FLOAT4 S22=vload4(0,uInput+inp_offset+ic_offset*10);\n"
" FLOAT4 S32=vload4(0,uInput+inp_offset+ic_offset*11);\n"
" FLOAT4 S03=vload4(0,uInput+inp_offset+ic_offset*12);\n"
" FLOAT4 S13=vload4(0,uInput+inp_offset+ic_offset*13);\n"
" FLOAT4 S23=vload4(0,uInput+inp_offset+ic_offset*14);\n"
" FLOAT4 S33=vload4(0,uInput+inp_offset+ic_offset*15);\n"
// Output transform: (a0+a1+a2, a1-a2+a3), applied first across slot rows
// (m<x><row>) and then across x in the four stores below.
" FLOAT4 m00=+S00+S01+S02;\n"
" FLOAT4 m10=+S10+S11+S12;\n"
" FLOAT4 m20=+S20+S21+S22;\n"
" FLOAT4 m30=+S30+S31+S32;\n"
" FLOAT4 m01=+S01-S02+S03;\n"
" FLOAT4 m11=+S11-S12+S13;\n"
" FLOAT4 m21=+S21-S22+S23;\n"
" FLOAT4 m31=+S31-S32+S33;\n"
" \n"
// NOTE(review): the offset arithmetic puts the channel block outside batch
// ((batchOffset + pos.y*batch) planes of dstHeight x dstWidth x 4). That is
// not the [batch,dstChannelC4,...] order the in-source comment names --
// confirm the intended layout with the host code.
" //NC4HW4 [batch,dstChannelC4,dstHeight,dstWidth]\n"
" //index: [batchOffset,pos.y,oyStart,oxStart]\n"
" int out_offset=(((batchOffset+ pos.y*batch)*dstHeight+oyStart)*dstWidth+oxStart)*4;\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m00+m10+m20;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+0;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m10-m20+m30;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+0;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m01+m11+m21;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4*dstWidth);\n"
" }\n"
" }\n"
" {\n"
" int ox=oxStart+1;\n"
" int oy=oyStart+1;\n"
" if (ox<dstWidth && oy<dstHeight) {\n"
" FLOAT4 res=bias+m11-m21+m31;\n"
"#ifdef RELU\n"
" res=max(res,(FLOAT4)(0));\n"
"#endif\n"
"#ifdef RELU6\n"
" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n"
"#endif\n"
" vstore4(res,0,uOutput+out_offset+4*dstWidth+4);\n"
" }\n"
" }\n"
" }\n"
"}\n"
// gemm_buf_intel: batched matrix multiply in the transform domain, using a
// 16-wide subgroup.
//
// pos.z indexes the outermost dimension of all three buffers, presumably the
// transform slot 0..alpha2-1. Per slot, the buffers are:
// - input0: [srcChannelC16, width, 16];
// - input1: [height, srcChannelC16, 16(ic), 16(oc)];
// - output: [height, width, 16].
//
// Each lane computes 8 consecutive columns (pos_x = pos.x*8) for output
// channel sglid of output block pos_y = get_group_id(1).
//
// No bounds check: this assumes the global size exactly covers
// (width/8) x (height*16) x alpha2 -- TODO confirm on host. alpha2 itself is
// unused in the body.
//
// NOTE(review): mul24/mad24 are exact only for operands within 24-bit signed
// range; very large tensors could exceed that.
"__attribute__((intel_reqd_sub_group_size(16)))\n"
"__kernel void gemm_buf_intel(__global const FLOAT* input0,\n"
" __global const FLOAT* input1,\n"
" __global FLOAT* output,\n"
" __private const int width,//ROUND_UP(wUnit*hUnit,8)\n"
" __private const int height,//dstChannelC16\n"
" __private const int srcChannelC16,\n"
" __private const int alpha2) {\n"
" int3 pos=(int3)(get_global_id(0),get_group_id(1),get_global_id(2));\n"
" const int sglid=get_sub_group_local_id();\n"
" const int pos_x=pos.x << 3;\n"
" const int pos_y=pos.y;\n"
" FLOAT8 o=(FLOAT8)(0);\n"
" const int kernel_base=mul24(mul24(mad24(pos.z,height,pos_y),srcChannelC16),256);\n"
" const int inp_base=mul24(mad24(mul24(pos.z,srcChannelC16),width,pos_x),16);\n"
" \n"
// For each 16-channel input block k, the block reads give:
// - wei0.sj / wei1.sj = weight[ic=j / 8+j][oc=sglid];
// - s.sj = input[column pos_x+j][ic=sglid].
// Shuffling s from lane c broadcasts input channel c for all 8 columns.
// The FP16 path shuffles the raw 16-bit patterns (ushort8) and reinterprets
// them as half8. Its _us block I/O and ushort shuffle overloads come from
// cl_intel_subgroups_short, which no pragma in this source enables
// explicitly -- presumably the driver exposes it; confirm.
" for(int k=0; k<srcChannelC16; ++k){\n"
" \n"
"#ifdef MNN_SUPPORT_FP16\n"
" FLOAT8 wei0=as_half8(intel_sub_group_block_read_us8((__global ushort*)(input1+kernel_base+k*256)));\n"
" FLOAT8 wei1=as_half8(intel_sub_group_block_read_us8((__global ushort*)(input1+kernel_base+k*256+8*16)));\n"
" FLOAT8 s=as_half8(intel_sub_group_block_read_us8((__global ushort*)(input0+inp_base+k*width*16)));\n"
" o=mad(wei0.s0,as_half8(intel_sub_group_shuffle(as_ushort8(s),0)),o);\n"
" o=mad(wei0.s1,as_half8(intel_sub_group_shuffle(as_ushort8(s),1)),o);\n"
" o=mad(wei0.s2,as_half8(intel_sub_group_shuffle(as_ushort8(s),2)),o);\n"
" o=mad(wei0.s3,as_half8(intel_sub_group_shuffle(as_ushort8(s),3)),o);\n"
" o=mad(wei0.s4,as_half8(intel_sub_group_shuffle(as_ushort8(s),4)),o);\n"
" o=mad(wei0.s5,as_half8(intel_sub_group_shuffle(as_ushort8(s),5)),o);\n"
" o=mad(wei0.s6,as_half8(intel_sub_group_shuffle(as_ushort8(s),6)),o);\n"
" o=mad(wei0.s7,as_half8(intel_sub_group_shuffle(as_ushort8(s),7)),o);\n"
" o=mad(wei1.s0,as_half8(intel_sub_group_shuffle(as_ushort8(s),8)),o);\n"
" o=mad(wei1.s1,as_half8(intel_sub_group_shuffle(as_ushort8(s),9)),o);\n"
" o=mad(wei1.s2,as_half8(intel_sub_group_shuffle(as_ushort8(s),10)),o);\n"
" o=mad(wei1.s3,as_half8(intel_sub_group_shuffle(as_ushort8(s),11)),o);\n"
" o=mad(wei1.s4,as_half8(intel_sub_group_shuffle(as_ushort8(s),12)),o);\n"
" o=mad(wei1.s5,as_half8(intel_sub_group_shuffle(as_ushort8(s),13)),o);\n"
" o=mad(wei1.s6,as_half8(intel_sub_group_shuffle(as_ushort8(s),14)),o);\n"
" o=mad(wei1.s7,as_half8(intel_sub_group_shuffle(as_ushort8(s),15)),o);\n"
"#else \n"
" FLOAT8 wei0=as_float8(intel_sub_group_block_read8((__global uint*)(input1+kernel_base+k*256)));\n"
" FLOAT8 wei1=as_float8(intel_sub_group_block_read8((__global uint*)(input1+kernel_base+k*256+8*16)));\n"
" FLOAT8 s=as_float8(intel_sub_group_block_read8((__global uint*)(input0+inp_base+k*width*16)));\n"
" o=mad(wei0.s0,intel_sub_group_shuffle(s,0),o);\n"
" o=mad(wei0.s1,intel_sub_group_shuffle(s,1),o);\n"
" o=mad(wei0.s2,intel_sub_group_shuffle(s,2),o);\n"
" o=mad(wei0.s3,intel_sub_group_shuffle(s,3),o);\n"
" o=mad(wei0.s4,intel_sub_group_shuffle(s,4),o);\n"
" o=mad(wei0.s5,intel_sub_group_shuffle(s,5),o);\n"
" o=mad(wei0.s6,intel_sub_group_shuffle(s,6),o);\n"
" o=mad(wei0.s7,intel_sub_group_shuffle(s,7),o);\n"
" o=mad(wei1.s0,intel_sub_group_shuffle(s,8),o);\n"
" o=mad(wei1.s1,intel_sub_group_shuffle(s,9),o);\n"
" o=mad(wei1.s2,intel_sub_group_shuffle(s,10),o);\n"
" o=mad(wei1.s3,intel_sub_group_shuffle(s,11),o);\n"
" o=mad(wei1.s4,intel_sub_group_shuffle(s,12),o);\n"
" o=mad(wei1.s5,intel_sub_group_shuffle(s,13),o);\n"
" o=mad(wei1.s6,intel_sub_group_shuffle(s,14),o);\n"
" o=mad(wei1.s7,intel_sub_group_shuffle(s,15),o);\n"
"#endif \n"
" }\n"
// Subgroup-uniform base pointer. The block write stores o.sj of lane l at
// out_offset + j*16 + l, i.e. column pos_x+j, output channel l.
" int out_offset=mul24(mad24(mad24(pos.z,height,pos_y),width,pos_x),16);\n"
"#ifdef MNN_SUPPORT_FP16\n"
" intel_sub_group_block_write_us8((__global ushort*)(output+out_offset),as_ushort8(o));\n"
"#else\n"
" intel_sub_group_block_write8((__global uint*)(output+out_offset),as_uint8(o));\n"
"#endif\n"
"}\n"
// End of the winogradTransform_subgroup_buf initializer.
;
#endif
#endif
}
